Conversation

abheesht17
Collaborator

@abheesht17 abheesht17 commented Jul 13, 2022

Resolves #241

Partially resolves #277

  • Greedy Search
  • Beam Search (will probably open a separate PR for this)
  • Top-p Search
  • Top-k Search
  • Random Search

Will have to think a bit more about Beam Search.
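For context, here is a rough eager-mode sketch of the shared interface these utilities follow (the exact signature and the greedy step are illustrative assumptions, not the PR's graph-compatible implementation):

import tensorflow as tf

def greedy_search(token_probability_fn, prompt, max_length):
    # Repeatedly append the highest-probability next token until `max_length`.
    while tf.shape(prompt)[1] < max_length:
        probs = token_probability_fn(prompt)  # [batch_size, vocab_size]
        next_token = tf.cast(tf.argmax(probs, axis=-1), prompt.dtype)
        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
    return prompt

The rest of this thread is about making loops like this work under tf.function (and eventually XLA), where the growing prompt shape is the tricky part.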

Member

@mattdangerw mattdangerw left a comment

Thank you! Left a few comments


expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
self.assertEqual(
Member

Would just using model.predict work? That would still hit all the compiled function paths, and allow you to avoid all this dummy metric stuff, which is hard to read.

Member

It could also be good to test the call on a batched dataset (where the batch size is not statically known), as well as on a single constant input, as you are doing here.
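A minimal sketch of those two cases (TestModel is the test model from this PR; everything else here is illustrative):

model = TestModel()

# A single constant input with a statically known batch size.
model.predict(tf.constant([[0, 1], [1, 2]]))

# A batched tf.data pipeline, where the batch size is not statically known.
ds = tf.data.Dataset.from_tensor_slices(tf.constant([[0, 1], [1, 2], [2, 3]]))
model.predict(ds.batch(2))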

Collaborator Author

Ah, man. Stupid me. Should have used model.predict :P

),
body=one_step,
loop_vars=[prompt],
shape_invariants=[tf.TensorShape(shape_invariants)],
Member

Can we just pass tf.TensorShape([None, None]) as the shape invariant? Generally we should support a static batch size of None; tf.data does this by default after calling .batch(), for example. It might simplify the code a bit.
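A sketch of that suggestion against the snippet above (one_step and max_length come from the surrounding code; the cond lambda is a placeholder):

prompt = tf.while_loop(
    cond=lambda prompt: tf.shape(prompt)[1] < max_length,
    body=one_step,
    loop_vars=[prompt],
    # Leave both batch size and sequence length dynamic, so inputs batched by
    # tf.data (static batch size of None) are supported.
    shape_invariants=[tf.TensorShape([None, None])],
)[0]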


inputs = tf.constant([[0, 1], [1, 2]])
model = TestModel()
model.compile(metrics=[dummy_metric])
Member

If you add jit_compile=True, does the test still pass?

If yes, we should test this with both jit_compile=True and False, using https://docs.pytest.org/en/6.2.x/parametrize.html

If not, we should either try to fix things with JIT compilation, or make sure we track that in a follow-up issue.
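A sketch of what that parametrization could look like (the actual test class and assertions in this PR are not shown here):

import pytest

@pytest.mark.parametrize("jit_compile", [True, False])
def test_generate_compiles(jit_compile):
    model = TestModel()
    model.compile(jit_compile=jit_compile)
    model.predict(tf.constant([[0, 1], [1, 2]]))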

Contributor

@abheesht17 Let's try adding a test case for jit_compile=True, and we can run it on GPU. We recently added GPU test support in this repo.

Collaborator Author

It's not working with jit_compile = True. Complete error logs: https://p.ip.fi/2TNt.

Looks like it won't work with shape_invariants.
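For reference, the XLA-friendly pattern discussed later in this thread (padding plus scatter updates) keeps the prompt at a fixed [batch_size, max_length] shape and overwrites one column per step, so no shape_invariants are needed. A hedged sketch, not the PR's actual code:

length = prompt.shape.as_list()[1]
padding = tf.fill((tf.shape(prompt)[0], max_length - length), pad_token_id)
prompt = tf.concat((prompt, padding), axis=-1)  # shape is fixed from here on

def one_step(index, prompt):
    probs = token_probability_fn(prompt)  # assumed to accept the padded prompt
    next_token = tf.cast(tf.argmax(probs, axis=-1), prompt.dtype)  # greedy step, for illustration
    batch_size = tf.shape(prompt)[0]
    # Overwrite column `index` in place; loop variable shapes never change.
    indices = tf.stack([tf.range(batch_size), tf.fill([batch_size], index)], axis=1)
    prompt = tf.tensor_scatter_nd_update(prompt, indices, next_token)
    return index + 1, prompt

_, prompt = tf.while_loop(
    cond=lambda index, _: index < max_length,
    body=one_step,
    loop_vars=(length, prompt),
)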

@mattdangerw
Member

Also, re: beam search, a separate PR sounds good!

Collaborator Author

@abheesht17 abheesht17 left a comment

@mattdangerw, thanks for the review! Addressed all comments, save the jit_compile one.


@mattdangerw
Member

/gcbrun

@mattdangerw
Member

I think a pull request went by recently where we stopped doing seeded random generation because of discrepancies.

#269

Is this safe to land as-is, @chenmoneygithub @jessechancy?

@jessechancy
Contributor

Seeded random generation should be removed. This is mainly because, even when fully seeded, the random output differs when the accelerator tests run on a GPU.

@abheesht17 abheesht17 changed the title from "Make Decoding Functions Graph-compatible" to "Make Decoding Functions Graph-compatible (with XLA Support!)" on Aug 10, 2022
Member

@mattdangerw mattdangerw left a comment

This is great. Just leaving some quick initial comments.

tf.cast(max_length, dtype=tf.int64),
),
body=one_step,
loop_vars=[state],
Member

Can we avoid the state dict and just do loop_vars=(length, prompt) here? Might be a little more readable.
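A minimal sketch of that suggestion, with the state dict replaced by a (length, prompt) tuple (cond and body are placeholders from the surrounding code):

length, prompt = tf.while_loop(
    cond=lambda length, prompt: length < max_length,
    body=one_step,
    loop_vars=(length, prompt),
)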


# Pad the prompt with `pad_token_id` to `max_length`. We use `map_fn` here
# because the batch_size might not be static.
prompt = tf.map_fn(
Member

I feel like we should be able to make this simpler; we are just padding a batched tensor with pad_token_id up to the sequence length, right? We should not need a map_fn for this.

loop_vars=[state],
)[0]

prompt = state["prompt"]
if end_token_id is not None:
prompt = mask_tokens_after_end_token(
Member

Do we even need this function anymore, if we are just starting with a correctly sized tensor filled with pad_token_id?

Member

Hmm, I guess we do, to avoid random tokens after the end_token_id.

Collaborator Author

Yep
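For readers following along, a hedged sketch of what mask_tokens_after_end_token could do (the names come from the snippet above; the PR's actual implementation may differ). It replaces everything after the first end_token_id with pad_token_id:

def mask_tokens_after_end_token(prompt, end_token_id, pad_token_id):
    # 1 at positions strictly after the first end_token_id, 0 elsewhere.
    is_end = tf.cast(tf.equal(prompt, end_token_id), prompt.dtype)
    after_end = tf.cast(tf.cumsum(is_end, axis=-1, exclusive=True) > 0, prompt.dtype)
    return prompt * (1 - after_end) + pad_token_id * after_end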

Member

@mattdangerw mattdangerw left a comment

A few comments, but this looks pretty good to me! I only commented on one of the four utilities, but comments apply to all.

length = prompt.shape.as_list()[1]

# Pad the prompt with `pad_token_id` to `max_length`.
prompt = tf.concat(
Member

nit: maybe split this into two lines for readability?

padding = tf.fill((tf.shape(prompt)[0], max_length - length), pad_token_id)
prompt = tf.concat((prompt, padding), axis=-1)

while i < max_length:
# If the prompt has reached our desired length, exit while loop.
pred = token_probability_fn(prompt)
length = prompt.shape.as_list()[1]
Member

Can we just do something like

batch_size, length = tf.shape(x)

And use that below? Then length and batch size are both tensors from the start.

Collaborator Author

Hmmm, I'll split this into two lines:

batch_size = tf.shape(prompt)[0]
length = tf.shape(prompt)[1]

because

batch_size, length = tf.shape(x)

does not work in graph mode.

Collaborator Author

Stack trace: https://p.ip.fi/6YAg

Member

Ah, destructuring is too fancy for AutoGraph, I forgot. Let's do

shape = tf.shape(prompt)
batch_size = shape[0]
length = shape[1]

return (length, prompt)

# Run a while loop till text of length `max_length` has been generated.
prompt = tf.while_loop(
Member

length, prompt = tf.while_loop(...)

Just to avoid that [1], which is not super readable.


class TestModel(tf.keras.Model):
def call(self, inputs, training=False):
if not training:
Member

Is there a reason you have to do the training switch here? It looks like you are never actually testing the training=True branch; might be nice to clean up the test a bit.

@mattdangerw
Member

@chenmoneygithub do you know why the accelerator testing is failing here? This would be a great one to actually test on accelerators.

@chenmoneygithub
Contributor

I found it out: the git branch has not been synced with the master branch, so the build file is outdated.

@abheesht17 Could you sync and push again? Thanks!

@abheesht17
Collaborator Author

> I found it out: the git branch has not been synced with the master branch, so the build file is outdated.
>
> @abheesht17 Could you sync and push again? Thanks!

Sure!

Member

@mattdangerw mattdangerw left a comment

LGTM! Thanks!

Contributor

@chenmoneygithub chenmoneygithub left a comment

Nice work! Dropped a comment on the test.

Also, could you help create a TODO(chenmoneygithub) at the top of text_generation.py saying we should refactor these utilities to share common code? The padding + scatter_update handling is more complex than before, so it would be nice if we could reuse it.

@@ -342,7 +406,7 @@ def test_generate_with_ragged_prompt(self):
def test_assert_probability_distribution_generation_is_correct(self):
def token_probability_fn(inputs):
batch_size = inputs.shape[0]
-prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+prob = tf.constant([[0.0, 0.0, 0.0, 1.0]])
Contributor

Do we need to change the number here? The original value seems to be more general?

Collaborator Author

Ah, yes. This was done to take care of accelerator testing. Seeded generation does not work there, so we've made the probability 1.
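For clarity, a sketch of the deterministic setup this describes (the real test may differ slightly): with all the probability mass on the last token, every sampler must pick token 3 at each step, so no seed is needed even on GPU.

def token_probability_fn(inputs):
    batch_size = inputs.shape[0]
    prob = tf.constant([[0.0, 0.0, 0.0, 1.0]])
    return tf.repeat(prob, batch_size, axis=0)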

@chenmoneygithub chenmoneygithub merged commit 34c0e27 into keras-team:master Aug 16, 2022
Successfully merging this pull request may close these issues:

  • Decoding Functions Not Working when jit_compile = True
  • Make Decoding Functions Graph-compatible